import itertools
import os
import shlex
import socket

import hjson

import scalevi.utils.utils as utils

###############################################################################
# Utilities: Server
###############################################################################
# Hostname prefixes recognized by on_server() and which_server().
RECOGNIZED_SERVERS = ["swarm", "gypsum"]
# Hostname prefixes recognized by which_worker().
RECOGNIZED_WORKERS = ["swarm", "node"]


def _match_hostname_prefix(hostname, prefixes):
    """Return the first entry of ``prefixes`` that ``hostname`` starts with.

    Falls back to ``hostname`` itself when no prefix matches.
    """
    for prefix in prefixes:
        if hostname.startswith(prefix):
            return prefix
    return hostname


def on_server(hostname=None):
    """Return True if the host name starts with any of RECOGNIZED_SERVERS.

    Args:
        hostname: Host name to classify; defaults to ``socket.gethostname()``.
    """
    if hostname is None:
        hostname = socket.gethostname()
    return any(hostname.startswith(prefix) for prefix in RECOGNIZED_SERVERS)


def which_worker(hostname=None):
    """Return the matching RECOGNIZED_WORKERS prefix, else the host name itself.

    Args:
        hostname: Host name to classify; defaults to ``socket.gethostname()``.
    """
    if hostname is None:
        hostname = socket.gethostname()
    # Driven by RECOGNIZED_WORKERS so the list and this function cannot drift.
    return _match_hostname_prefix(hostname, RECOGNIZED_WORKERS)


def which_server(hostname=None):
    """Return the matching RECOGNIZED_SERVERS prefix, else the host name itself.

    Args:
        hostname: Host name to classify; defaults to ``socket.gethostname()``.
    """
    if hostname is None:
        hostname = socket.gethostname()
    # Driven by RECOGNIZED_SERVERS so on_server() and this function agree.
    return _match_hostname_prefix(hostname, RECOGNIZED_SERVERS)




###############################################################################
# Utilities: Different hyper-parameter permutations
###############################################################################

def _value_based_prefix(config_dict, hparam, value):
    """Return the key prefix of the sub-ranges selected by ``hparam == value``.

    If ``config_dict`` has ``<hparam>_keys`` and ``value`` is one of its keys,
    the mapped name is the prefix; otherwise ``value`` itself is used (and must
    then be a string, since it is passed to ``str.startswith``).
    """
    keys_name = hparam + "_keys"
    if keys_name in config_dict and value in config_dict[keys_name]:
        return config_dict[keys_name][value]
    return value


def _sub_ranges_for_prefix(config_dict, prefix):
    """Collect every ``<prefix>..._range`` entry of ``config_dict``.

    Returns:
        (names, ranges): the hparam names (key without the ``_range`` suffix)
        and their range lists, in ``config_dict`` iteration order.
    """
    suffix = "_range"
    names, ranges = [], []
    for key, value in config_dict.items():
        # Match the full "_range" suffix: that is exactly what gets sliced off
        # below (the old check used "range", so e.g. "Aorange" became "A").
        if key.startswith(prefix) and key.endswith(suffix):
            names.append(key[:-len(suffix)])
            ranges.append(value)
    return names, ranges


def get_all_permutations(config_dict):
    """Expand a sweep configuration into one dict per hyper-parameter setting.

    Reads from ``config_dict``:
        value_based_iterates: names whose chosen value additionally selects
            sub-ranges, i.e. every ``<prefix>..._range`` key, where the prefix
            is the chosen value or its mapping in ``<name>_keys`` (the
            ``_keys`` mapping is optional).
        common_iterates: names swept independently.
        <name>_range: candidate values for each iterate. An element may be any
            hjson-compatible value; the outer list is the range.
        hyper_params_to_iter: every hparam each result should contain; those
            not set by a given permutation are filled with ``None``. All value
            based and common iterates should be listed here.

    Returns:
        A list of dicts. The value-based and common iterates are combined with
        ``itertools.product``, and each combination is further expanded by the
        product of its selected sub-ranges. If a sub-range name clashes with an
        iterate name, the iterate's value wins.
    """
    iterates = (config_dict['value_based_iterates']
                + config_dict['common_iterates'])
    permutation_dicts = []
    for perm in itertools.product(
            *[config_dict[name + "_range"] for name in iterates]):
        permutation_dict = dict(zip(iterates, perm))

        # Gather the sub-ranges selected by every value-based iterate.
        sub_names, sub_ranges = [], []
        for hparam in config_dict['value_based_iterates']:
            prefix = _value_based_prefix(
                config_dict, hparam, permutation_dict[hparam])
            names, ranges = _sub_ranges_for_prefix(config_dict, prefix)
            sub_names.extend(names)
            sub_ranges.extend(ranges)

        for values in itertools.product(*sub_ranges):
            # permutation_dict is merged last so iterates win on name clashes.
            permutation_dicts.append(
                {**dict(zip(sub_names, values)), **permutation_dict})

    # Every hparam to iterate must be present; unset ones become None.
    for permutation in permutation_dicts:
        for iterate in config_dict["hyper_params_to_iter"]:
            permutation.setdefault(iterate, None)
    return permutation_dicts


def config_to_skip(config_dict):
    """Return True if this configuration should not be run, else False.

    A configuration is skipped when ``minibatch_size`` exceeds ``N_leaves``.
    It is also skipped when its ``var_dist`` appears in the rule table below
    and any listed hparam that has a ``<hparam>_range`` is not at the first
    value of that range. Only the first-value combination of those hparams
    is kept for that ``var_dist``.
    """
    # (var_dist value, hparams whose non-first range values are skipped)
    skip_rules = (
        ("BranchGaussianWithSampleEval", ("encoder_encode_θ",)),
        ("GaussianWithSampleEval", ("N_leaves",)),
        ("DiagonalWithSampleEval", ("N_leaves",)),
        ("BlockGaussianWithSampleEval", ("N_leaves",)),
    )

    if config_dict['minibatch_size'] > config_dict['N_leaves']:
        return True

    for dist_name, hparams in skip_rules:
        if config_dict["var_dist"] == dist_name and any(
                hparam + "_range" in config_dict
                and config_dict[hparam] != config_dict[hparam + "_range"][0]
                for hparam in hparams):
            return True
    return False

###############################################################################
# Utilities: Updates config file before running an experiment
###############################################################################

def perm_to_configuration(perm_id, perm, config_dict):
    """Write one permutation into ``config_dict`` in place, then derive extras.

    Copies ``perm[name]`` for every name in ``hyper_params_to_iter`` (a
    ``KeyError`` is raised if ``perm`` lacks one). It then calls
    ``perm_to_configuration_further_changes`` with ``perm_id``.
    """
    config_dict.update(
        (name, perm[name]) for name in config_dict['hyper_params_to_iter'])
    perm_to_configuration_further_changes(perm_id, config_dict)


def perm_to_configuration_further_changes(perm_id, config_dict):
    """Derive dependent settings in ``config_dict`` (in place) for one permutation.

    - ``seed`` is set to ``perm_id * 3 + 100``.
    - For GaussianWithSampleEval, DiagonalWithSampleEval and
      BlockGaussianWithSampleEval, ``minibatch_use`` is set to False and
      ``minibatch_size`` to ``N_leaves``.
    - Otherwise, when ``minibatch_map_use`` is truthy, ``minibatch_size`` is
      taken from ``minibatch_map[var_dist][str(N_leaves)]``. The current
      ``minibatch_size`` is kept if either lookup level is missing.
    - When ``optimizer_step_drop_count_map_use`` is truthy,
      ``optimizer_step_drop_count`` is looked up by ``str(N_leaves)``.
    - When ``encoder_map_use`` is truthy, ``encoder`` is looked up by
      ``str(var_dist)``.
    In both map lookups, the current value is kept when the key is missing.
    """
    config_dict['seed'] = perm_id * 3 + 100
    if config_dict['var_dist'] in ["GaussianWithSampleEval",
                                   "DiagonalWithSampleEval",
                                   "BlockGaussianWithSampleEval"]:
        config_dict['minibatch_use'] = False
        config_dict['minibatch_size'] = config_dict['N_leaves']
    elif config_dict.get('minibatch_map_use', False):
        # A var_dist without its own entry keeps the current minibatch_size,
        # mirroring the fallback used for a missing N_leaves entry
        # (previously this raised KeyError).
        per_dist = config_dict['minibatch_map'].get(config_dict['var_dist'], {})
        config_dict['minibatch_size'] = per_dist.get(
            str(config_dict['N_leaves']),
            config_dict['minibatch_size'])

    if config_dict.get('optimizer_step_drop_count_map_use', False):
        config_dict['optimizer_step_drop_count'] = (
            config_dict['optimizer_step_drop_count_map'].get(
                str(config_dict['N_leaves']),
                config_dict['optimizer_step_drop_count']))

    if config_dict.get('encoder_map_use', False):
        config_dict['encoder'] = config_dict['encoder_map'].get(
            str(config_dict['var_dist']),
            config_dict['encoder'])


def n_iter_from_epoch(config_dict):
    """Convert ``n_epoch`` into an iteration count.

    When ``minibatch_use`` is exactly ``True``, each epoch counts as
    ``N_leaves // minibatch_size`` iterations. Otherwise one epoch is one
    iteration. The result is truncated with ``int()``.
    """
    for required in ('minibatch_size', 'n_epoch'):
        assert required in config_dict.keys()
    if config_dict['minibatch_use'] is not True:
        return int(config_dict['n_epoch'])
    batches_per_epoch = config_dict['N_leaves'] // config_dict['minibatch_size']
    return int(config_dict['n_epoch'] * batches_per_epoch)


def n_iter_from_dict(config_dict):
    """Return ``config_dict['n_iter']`` converted with ``int()``."""
    raw_n_iter = config_dict['n_iter']
    return int(raw_n_iter)


def get_n_iter(config_dict):
    """Return the number of iterations to run.

    Uses n_iter_from_epoch only when ``fix_epoch`` is exactly ``True``;
    otherwise uses n_iter_from_dict.
    """
    if config_dict.get('fix_epoch', False) is not True:
        return n_iter_from_dict(config_dict)
    return n_iter_from_epoch(config_dict)

#####################################################################
#   Utilities: SLURM File
#####################################################################
def config_to_mem(config_dict, experimenter_dict):
    """Set ``mem_per_job`` to the fixed string '12000' (config_dict is unused)."""
    experimenter_dict.update(mem_per_job='12000')


def config_to_partition(config_dict, experimenter_dict):
    """Set ``partition_per_job`` from the already-set ``time_per_job``.

    Uses time_to_partition on the "swarm" server and time_to_partition_gpu
    everywhere else. config_dict is unused.
    """
    on_swarm = which_server() == "swarm"
    to_partition = time_to_partition if on_swarm else time_to_partition_gpu
    experimenter_dict['partition_per_job'] = to_partition(
        experimenter_dict['time_per_job'])


def config_to_cores(config_dict, experimenter_dict):
    """Copy ``n_cores`` (as a string) into experimenter_dict when it is swept."""
    if "n_cores" not in config_dict['hyper_params_to_iter']:
        return
    experimenter_dict['n_cores'] = str(config_dict['n_cores'])


def config_to_gpus(config_dict, experimenter_dict):
    """Copy ``n_gpus`` (as a string) into experimenter_dict when it is swept."""
    if "n_gpus" not in config_dict['hyper_params_to_iter']:
        return
    experimenter_dict['n_gpus'] = str(config_dict['n_gpus'])


def config_to_time(config_dict, experimenter_dict):
    """Set ``time_per_job`` based on the current server.

    '0-12:00' on swarm, "0-04:00" on any other host; config_dict is unused.
    """
    if which_server() == "swarm":
        time_limit = '0-12:00'
    else:
        time_limit = "0-04:00"
    experimenter_dict['time_per_job'] = time_limit


def exp_dict_updates_from_config(config_dict, experimenter_dict):
    """Fill experimenter_dict (in place) with the SLURM job settings.

    Order matters: config_to_partition reads the ``time_per_job`` that
    config_to_time sets.
    """
    for apply_update in (config_to_mem,
                         config_to_cores,
                         config_to_gpus,
                         config_to_time,
                         config_to_partition):
        apply_update(config_dict, experimenter_dict)


def generate_run_command(environment, file, perm_id, uniq_name, config_dict):
    """Build the shell snippet that activates ``environment`` and runs ``file``.

    ``file`` must contain "run_vi" or "monitor_trace". For monitor_trace
    scripts, ``-plot <monitor_trace_plot> -swr 1`` is also passed. The config
    is serialized with ``hjson.dumps`` and passed as a single ``-config``
    argument.

    Raises:
        ValueError: if ``file`` is neither a run_vi nor a monitor_trace script.
    """
    # Every interpolated value goes through shlex.quote. Wrapping the hjson
    # dump in bare single quotes broke the command whenever a config value
    # contained "'". Values made only of shell-safe characters are emitted
    # unchanged, and the dump gets the same '...' wrapping as before.
    if "run_vi" in file:
        extra_args = ""
    elif "monitor_trace" in file:
        plot = shlex.quote(str(config_dict['monitor_trace_plot']))
        extra_args = f" -plot {plot} -swr 1"
    else:
        raise ValueError(
            f"Cannot build a run command for {file!r}: "
            "expected a run_vi or monitor_trace script")
    config_arg = shlex.quote(hjson.dumps(config_dict))
    return (
        f"source activate {shlex.quote(str(environment))}\n"
        f"python {shlex.quote(str(file))} -id {perm_id}"
        f"{extra_args}"
        f" -uname {shlex.quote(str(uniq_name))} -config {config_arg}"
        "\nexit"
    )


def _slurm_time_to_minutes(time):
    """Convert a SLURM ``days-hours[:minutes[:seconds]]`` time limit to minutes.

    Raises:
        ValueError: if ``time`` has no ``days-`` part or a non-integer field.
    """
    days, sep, clock = time.partition('-')
    if not sep:
        raise ValueError(
            f"expected 'days-hours[:minutes[:seconds]]', got {time!r}")
    # Pad so "D-H" and "D-H:M" parse like "D-H:M:S".
    hours, minutes, seconds = (
        [int(field) for field in clock.split(':')] + [0, 0])[:3]
    return int(days) * 24 * 60 + hours * 60 + minutes + seconds / 60


def time_to_partition(time):
    """Return 'longq' for time limits over 12 hours, else "defq".

    ``time`` is a SLURM ``days-hours[:minutes[:seconds]]`` string, e.g.
    ``"0-12:00"``.
    """
    # Compare the whole duration. The hour field alone missed "0-12:30", and
    # slicing two characters broke single-digit hours like "0-4:00".
    if _slurm_time_to_minutes(time) > 12 * 60:
        return 'longq'
    return "defq"


def time_to_partition_gpu(time):
    """Return '1080ti-long' for time limits over 4 hours, else "2080ti-short".

    ``time`` is a SLURM ``days-hours[:minutes[:seconds]]`` string, e.g.
    ``"0-04:00"``.
    """
    if _slurm_time_to_minutes(time) > 4 * 60:
        return '1080ti-long'
    return "2080ti-short"


#####################################################################
#   Utilities: Configuration File
#####################################################################
def get_config_dirname(uniq_name):
    """Return the directory for saved experiment configs (``uniq_name`` is ignored)."""
    configs_dir = "data/experiments/configs/"
    return configs_dir


def get_config_filename(uniq_name, directory_name):
    """Return ``<directory_name><uniq_name>.hjson``; no separator is inserted."""
    return "".join([directory_name, str(uniq_name), ".hjson"])

def update_config(config_dict):
    """Store the result of get_n_iter in ``config_dict['n_iter']`` (in place)."""
    resolved_n_iter = get_n_iter(config_dict)
    config_dict['n_iter'] = resolved_n_iter

def get_saved_config(uniq_name, verbose=True):
    """Load and return the hjson config saved for ``uniq_name``.

    Args:
        uniq_name: Name the config was saved under (see save_config).
        verbose: Print the path that was loaded.

    Raises:
        NotImplementedError: if no saved config file exists.
    """
    config_dict_dirname = get_config_dirname(uniq_name)
    config_dict_filename = get_config_filename(uniq_name,
                                               config_dict_dirname)

    if not os.path.exists(config_dict_filename):
        # NOTE(review): FileNotFoundError would be more accurate; the
        # exception type is kept so existing handlers keep working.
        raise NotImplementedError(
            f"No saved config file at: {config_dict_filename}")
    # hjson files are hand-editable and may contain non-ASCII keys
    # (e.g. 'encoder_encode_θ'), so don't depend on the locale encoding.
    with open(config_dict_filename, 'r', encoding='utf-8') as f:
        config_dict = hjson.load(f)
    if verbose:
        print(f"Loaded config file from: {config_dict_filename}")
    return config_dict


def save_config(config_dict, uniq_name, verbose=True):
    """Write ``config_dict`` as hjson to the config file for ``uniq_name``.

    The config directory is created first via utils.create_dir. With
    ``verbose``, the written path is printed.
    """
    target_dir = get_config_dirname(uniq_name)
    utils.create_dir(target_dir)
    target_file = get_config_filename(uniq_name, target_dir)

    with open(target_file, 'w') as out:
        hjson.dump(config_dict, out)
    if verbose:
        print(f"Saved config file at: {target_file}")


#####################################################################
#   Utilities: Figures
#####################################################################
def get_figures_dirname(uniq_name):
    """Return the directory for experiment figures (``uniq_name`` is ignored)."""
    figures_dir = "data/experiments/figures/"
    return figures_dir


def get_figures_filename(uniq_name, directory_name):
    """Return ``<directory_name><uniq_name>`` (no separator, no extension)."""
    return "".join([directory_name, str(uniq_name)])


#####################################################################
#   Utilities: Results
#####################################################################

def get_results_dirname(config_dict):
    """Return the results directory for a configuration.

    The path is ``data/experiments/<name>/<value>/...``. It covers each name
    in ``hyper_params_to_iter`` (default ``['var_dist']``) whose value is not
    None or "None"; a missing value is written as "-". Spaces in the final
    path are replaced with underscores.
    """
    hparam_names = config_dict.get('hyper_params_to_iter', ['var_dist'])
    segments = []
    for name in hparam_names:
        value = config_dict.get(name, "-")
        if value in [None, "None"]:
            continue
        segments.append("/".join((name, str(value))))
    path = "data/experiments/" + "/".join(segments)
    return path.replace(" ", "_")


def get_results_filename(dir_name, token, uname):
    """Return the results path ``<dir_name>/<token>_<uname>.pkl``."""
    return "{}/{}_{}.pkl".format(dir_name, token, uname)

def loader(token, config_dict, uname=None, verbose=False):
    """Load the ``<token>_<uname>.pkl`` results file for this configuration.

    ``uname`` falls back to ``config_dict['uname']`` when falsy. If the uname
    contains "test", nothing is loaded and None is returned. Note that the
    results directory is created via utils.create_dir even though this only
    loads.
    """
    uname = uname or config_dict['uname']
    if "test" in uname:
        return None
    dir_name = get_results_dirname(config_dict)
    utils.create_dir(dir_name)
    results_path = get_results_filename(dir_name, token, uname)
    return utils.load_objects(results_path, verbose)

def saver(token, results, config_dict, uname=None, verbose=False):
    """Save ``results`` via utils.dump_objects to ``<results dir>/<token>_<uname>.pkl``.

    ``uname`` falls back to ``config_dict['uname']`` when falsy. If the uname
    contains "test", nothing is saved.
    """
    uname = uname or config_dict['uname']
    if "test" in uname:
        return
    dir_name = get_results_dirname(config_dict)
    utils.create_dir(dir_name)
    results_path = get_results_filename(dir_name, token, uname)
    utils.dump_objects(results, results_path, verbose)

